import json
import torch
from torch.utils.data import Dataset
from PIL import Image
import clip


class TextImageDataset(Dataset):
    """Dataset of (prompt tokens, positive image, negative image) triplets.

    Samples are read from a JSON file whose top-level value is an object; only
    its values are kept (the keys are discarded), in file order. Each value is
    expected to provide ``"prompts"``, ``"positive"`` and ``"negative"`` keys.

    ``"positive"`` / ``"negative"`` are file paths of one of two kinds:

    * raster images (``.jpg``, ``.jpeg``, ``.png``; matched case-insensitively),
      which are opened with PIL, converted to RGB and passed through
      ``preprocess``;
    * ``.pt`` tensor files, which are loaded onto the CPU and returned after
      squeezing dimension 0 (a no-op unless that dimension has size 1).

    Args:
        json_file: Path to the UTF-8 encoded JSON annotation file.
        preprocess: Callable applied to each loaded ``PIL.Image`` (presumably
            the transform returned by ``clip.load`` — confirm against caller).
    """

    # Suffixes routed through PIL + ``preprocess``; compared lower-cased.
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")
    # Suffix of pre-computed tensors loaded with ``torch.load``; compared lower-cased.
    _TENSOR_SUFFIX = ".pt"

    def __init__(self, json_file, preprocess):
        # JSON is UTF-8 by spec; be explicit so non-ASCII prompts decode the
        # same way regardless of the platform's default locale encoding.
        with open(json_file, "r", encoding="utf-8") as f:
            self.data = list(json.load(f).values())
        self.preprocess = preprocess

    def __len__(self):
        return len(self.data)

    def _load_image(self, path):
        """Load one positive/negative entry from ``path``.

        Returns the preprocessed image for raster files, or the stored tensor
        (on CPU, detached, dim 0 squeezed) for ``.pt`` files.

        Raises:
            ValueError: if the file extension is not supported.
        """
        lowered = path.lower()
        if lowered.endswith(self._IMAGE_SUFFIXES):
            # Context manager closes the file handle promptly; convert() loads
            # the pixel data and returns a new image that outlives the handle.
            with Image.open(path) as img:
                rgb = img.convert("RGB")  # ensure the image is in RGB mode
            return self.preprocess(rgb)
        if lowered.endswith(self._TENSOR_SUFFIX):
            # weights_only=True limits unpickling to tensors and basic types.
            tensor = torch.load(path, map_location="cpu", weights_only=True)
            return tensor.squeeze(0).detach()
        raise ValueError(f"Unsupported file type for {path}")

    def __getitem__(self, idx):
        """Return ``(prompt_tokens, positive, negative)`` for sample ``idx``.

        ``prompt_tokens`` is row 0 of ``clip.tokenize(sample["prompts"],
        truncate=True)``; ``truncate=True`` makes CLIP cut over-long prompts
        instead of raising.

        Raises:
            KeyError: if the sample lacks one of the required keys.
            ValueError: if an image path has an unsupported extension.
        """
        sample = self.data[idx]
        # Read all keys up front so a malformed sample fails before any I/O.
        prompt = sample["prompts"]
        positive_path = sample["positive"]
        negative_path = sample["negative"]

        positive_img = self._load_image(positive_path)
        negative_img = self._load_image(negative_path)

        # NOTE(review): if "prompts" holds a list of strings, [0] keeps only
        # the first one's tokens — confirm this is intended.
        prompt = clip.tokenize(prompt, truncate=True)[0]

        return prompt, positive_img, negative_img
